import pandas as pd
import numpy as np
import sqlalchemy
from sqlalchemy import create_engine
import time
import matplotlib.pyplot as plt
%matplotlib inline
# SQLAlchemy database URL used by every helper in this file (dialect
# ``mssql`` with the ``pymssql`` driver).
# NOTE(review): the user name and password are embedded here in plain text;
# consider loading the URL from the environment or an untracked config file.
CONNECTION_STRING = 'mssql+pymssql://IVYuser:resuyvi@vita.ieor.columbia.edu'
# Gets the database connection
def get_connection():
    """Open and return a new database connection.

    A fresh SQLAlchemy engine is built from ``CONNECTION_STRING`` on every
    call, and a connection obtained from that engine is handed to the caller.
    """
    return create_engine(CONNECTION_STRING).connect()
# Query database and return results in dataframe
def query_dataframe(query, connection=None):
    """Run a SQL query and return the result set as a DataFrame.

    Args:
        query: SQL text handed to ``pd.read_sql``.
        connection: Optional open database connection. When omitted, a
            connection is obtained from ``get_connection()`` for this call
            only and is closed again before returning.

    Returns:
        The ``pandas.DataFrame`` produced by ``pd.read_sql``.
    """
    owns_connection = connection is None
    if owns_connection:
        connection = get_connection()
    try:
        return pd.read_sql(query, connection)
    finally:
        # Close only what was opened here; a caller-supplied connection is
        # left open so the caller can keep using it.
        if owns_connection:
            connection.close()
# Query database using external file and return results in dataframe
def query_dataframe_f(filename, connection=None):
    """Run the SQL stored in a file and return the result as a DataFrame.

    Args:
        filename: Path of a text file whose entire contents are the query.
        connection: Optional open database connection. When omitted, a
            connection is obtained from ``get_connection()`` for this call
            only and is closed again before returning.

    Returns:
        The ``pandas.DataFrame`` produced by ``pd.read_sql``.
    """
    # Read the file before touching the database so that a missing or
    # unreadable file never leaves a freshly opened connection behind.
    with open(filename, 'r') as handle:
        sql_text = handle.read()

    owns_connection = connection is None
    if owns_connection:
        connection = get_connection()
    try:
        return pd.read_sql(sql_text, connection)
    finally:
        # Close only what was opened here; a caller-supplied connection is
        # left open so the caller can keep using it.
        if owns_connection:
            connection.close()
def data_frame_to_sql(df, table_name):
    """Write a DataFrame to a database table, replacing the table if present.

    Args:
        df: DataFrame to persist.
        table_name: Name of the destination table. ``if_exists='replace'``
            is passed to ``DataFrame.to_sql``, so an existing table with this
            name is replaced.
    """
    engine = create_engine(CONNECTION_STRING)
    try:
        df.to_sql(table_name, engine, if_exists='replace')
    finally:
        # The engine is private to this call, so release its pooled
        # connections now instead of leaving them open until it is
        # garbage-collected.
        engine.dispose()
def get_stock_data():
    """Fetch price rows around each index-change event as a DataFrame.

    The query joins ``XFDATA.dbo.SECURITY_PRICE`` to
    ``XF.db_datawriter.hi2179_SP500_comp`` on the security id and keeps only
    price rows dated within 35 days of both the announcement date and the
    change date. It also returns 0/1 flags marking rows that fall exactly on
    either date, and open/close prices multiplied by ``AdjustmentFactor2``.

    Returns:
        The DataFrame returned by ``query_dataframe`` for this query.
    """
    sql = '''
select Date,
case when sp.Date = co.AnnouncementDate then 1 else 0 end as IsAnnouncementDate,
case when sp.Date = co.ChangeDate then 1 else 0 end as IsChangeDate,
BidLow,
AskHigh,
OpenPrice,
ClosePrice,
OpenPrice * AdjustmentFactor2 as Adj_OpenPrice,
ClosePrice * AdjustmentFactor2 as Adj_ClosePrice,
Volume,
TotalReturn,
AdjustmentFactor,
AdjustmentFactor2,
inSecurityID as SecurityID,
inTicker as Ticker,
inName as Name,
inSector as Sector,
AnnouncementDate,
ChangeDate,
co.ID as DataID
from XFDATA.dbo.SECURITY_PRICE sp
join XF.db_datawriter.hi2179_SP500_comp co on sp.SecurityID = co.inSecurityID
and abs(datediff(day,sp.Date,co.AnnouncementDate)) <= 35
and abs(datediff(day,sp.Date,co.ChangeDate)) <= 35
'''
    return query_dataframe(sql)
def _plot_event_window(data_id, group):
    """Plot adjusted close price and volume for one ``DataID`` group.

    Args:
        data_id: The ``DataID`` value identifying the group; shown in the
            title.
        group: Rows sharing that ``DataID``. The ``Date``, ``Adj_ClosePrice``,
            ``Volume``, ``AnnouncementDate``, ``ChangeDate``, ``Name`` and
            ``Sector`` columns are read.
    """
    # Event metadata is taken from the first row of the group.
    announcement_date = group.AnnouncementDate.values[0]
    change_date = group.ChangeDate.values[0]
    announcement_date_str = pd.to_datetime(announcement_date).strftime('%Y%m%d')
    change_date_str = pd.to_datetime(change_date).strftime('%Y%m%d')
    name = group.Name.values[0]
    sector = group.Sector.values[0]

    # The query has no ORDER BY, so rows may arrive in any order; a line plot
    # connects points in row order and therefore needs them chronological.
    ordered = group.sort_values('Date')

    fig, ax1 = plt.subplots(figsize=(18, 5))
    ax2 = ax1.twinx()
    # Every artist gets a label so that the legend calls below have entries
    # to show instead of producing an empty legend and a warning.
    ax1.plot(ordered.Date, ordered.Adj_ClosePrice, 'b', label='Adj_ClosePrice')
    ax2.plot(ordered.Date, ordered.Volume, '--k', label='Volume')
    plt.axvline(x=announcement_date, color='r', ls='dashed',
                label='AnnouncementDate')
    plt.axvline(x=change_date, color='b', ls='dashed', label='ChangeDate')
    ax1.legend(loc=2)
    ax2.legend(loc=1)
    plt.grid(True)
    # str.format keeps the original title text (including the double space
    # before "A:") while tolerating non-string name/sector values.
    plt.title('{} {}  A:{} C:{} {}'.format(
        data_id, name, announcement_date_str, change_date_str, sector))
    plt.show()
    # Release the figure so that looping over many groups does not
    # accumulate open figures.
    plt.close(fig)


data = get_stock_data()
for data_id, group in data.groupby('DataID'):
    _plot_event_window(data_id, group)